The MNIST dataset is a popular dataset for data science and machine learning tasks. It consists of images of handwritten digits.
We’ll be working with a subset of 5,000 images from this dataset.
library(tidyverse)
library(plotly)
mnist <- read_csv('data/mnist.csv', n_max = 5000)
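Before doing anything else, it’s worth a quick check of what we read in: the dimensions, and how many images we have of each digit. This is a small sketch that relies only on the label column the file provides.
# Quick sanity check: number of rows/columns and the distribution of digits
dim(mnist)
mnist %>% count(label)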
mnist_umap <- readRDS('data/mnist_umap.RDS')
mnist_umap %>%
as_tibble() %>%
ggplot(aes(x = V1, y = V2)) +
geom_point(size = 0.25)
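The mnist_umap object loaded above is a precomputed two-dimensional UMAP embedding of the 784 pixel columns, which is what the plot displays. As a rough sketch of how such an embedding could be produced (this assumes the uwot package, and the parameter values are illustrative rather than the ones used to build the saved file):
# Illustrative only -- not necessarily how the saved embedding was created
library(uwot)
umap_coords <- mnist %>%
  select(-label) %>%
  as.matrix() %>%
  umap(n_neighbors = 15, min_dist = 0.1)
The result is a matrix with one row per image and two columns, matching the object read from the RDS file.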
First, let’s cluster the MNIST data. We’ll start by trying k-means.
set.seed(123)
max_clusters <- 20
# Total within-cluster sum of squares for each candidate number of clusters
ss <- numeric(max_clusters)
for (i in 1:max_clusters){
  clusters <- mnist_umap %>%
    kmeans(centers = i)
  ss[i] <- clusters$tot.withinss
}
wss <- tibble(clusters = 1:max_clusters, wss = ss)
wss %>% ggplot(aes(x = clusters, y = wss)) +
  geom_point() +
  geom_line() +
  scale_x_continuous(breaks = 1:max_clusters) +
  labs(title = "K-Means Clustering Results", x = "Number of Clusters", y = "Within-Cluster Total Sum of Squares")
Based on the plot, it appears that 4 might be a good number of clusters to use.
set.seed(123)
mnist_kmeans <- mnist_umap %>%
  kmeans(centers = 4)
mnist_umap %>%
as_tibble() %>%
mutate(cluster = as.factor(mnist_kmeans$cluster)) %>%
ggplot(aes(x = V1, y = V2, color = cluster)) +
geom_point(size = 0.25) +
guides(colour = guide_legend(override.aes = list(size=2)))
However, the scatterplot suggests there are at least 6 distinct clusters, so let’s rerun k-means with 6 centers.
set.seed(123)
mnist_kmeans <- mnist_umap %>%
kmeans(centers = 6)
mnist_umap %>%
as_tibble() %>%
mutate(cluster = as.factor(mnist_kmeans$cluster)) %>%
ggplot(aes(x = V1, y = V2, color = cluster)) +
geom_point(size = 0.25) +
guides(colour = guide_legend(override.aes = list(size=2)))
However, even with 6 centers, k-means does not do a good job of identifying what look like the natural clusters in our dataset. This is because k-means looks for identically-sized, spherical clusters.
Let’s look at an alternative approach: density-based clustering. Rather than looking for spherical clusters, density-based clustering looks for high-density areas separated by lower-density areas.
We’ll specifically be using HDBSCAN, a hierarchical density-based clustering method. We don’t have to specify the number of clusters ahead of time; the algorithm decides that for us.
library(dbscan)
set.seed(123)
# minPts sets the smallest group of points the algorithm will treat as a cluster
mnist_dbscan <- mnist_umap %>%
  hdbscan(minPts = 30)
mnist_umap %>%
as_tibble() %>%
mutate(cluster = as.factor(mnist_dbscan$cluster)) %>%
ggplot(aes(x = V1, y = V2, color = cluster)) +
geom_point(size = 0.25) +
guides(colour = guide_legend(override.aes = list(size=2)))
We end up with what looks to be a more sensible clustering.
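Since HDBSCAN chose the number of clusters for us, it’s worth checking what it decided. Tabulating the labels shows the cluster sizes (points labeled 0 were left as noise), and plotting the fitted object draws the simplified hierarchy the flat clustering was cut from. This is just a quick sketch; show_flat is an argument of the dbscan package’s plot method that highlights the selected clusters.
# Cluster sizes; a label of 0 means the point was not assigned to any cluster
table(mnist_dbscan$cluster)
# The condensed cluster hierarchy, with the selected flat clusters marked
plot(mnist_dbscan, show_flat = TRUE)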
Another advantage of HDBSCAN is that it can identify potential outlier points. K-means must assign every point to a cluster, whereas HDBSCAN can label points as outliers. It does this on the basis of a GLOSH score, which compares the density at each observation to the density of nearby points.
p <- mnist_umap %>%
as_tibble() %>%
mutate(outlier_score = mnist_dbscan$outlier_scores) %>%
ggplot(aes(x = V1, y = V2, color = outlier_score)) +
geom_point(size = 0.5) +
scale_color_viridis_c()
ggplotly(p)
We can highlight the most likely outliers.
# Indices of the ten observations with the highest GLOSH outlier scores
potential_outliers <- order(mnist_dbscan$outlier_scores, decreasing = TRUE)[1:10]
p <- mnist_umap %>%
  as_tibble() %>%
  ggplot(aes(x = V1, y = V2)) +
  geom_point(size = 0.25) +
  geom_point(data = mnist_umap[potential_outliers, ] %>% as_tibble(), aes(x = V1, y = V2), fill = 'orange', size = 2, shape = 21)
ggplotly(p)
Finally, we can see what these potential outliers look like. Here is the highest-scoring one:
im <- matrix(mnist[potential_outliers[1], ] %>% select(-label) %>% unlist(), nrow = 28, byrow = FALSE)
# Reverse each row so the digit is drawn right side up
for (i in 1:nrow(im)){
  im[i, ] <- rev(im[i, ])
}
par(pty = "s")
image(1:28, 1:28, im, col = gray(rev((0:255)/255)))
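The same reshaping can be repeated for the rest of the flagged points. Here is a small sketch that draws all ten in a single panel, assuming the same 28 x 28, column-wise pixel layout used above.
# Draw the ten highest-scoring potential outliers in a 2 x 5 grid
par(mfrow = c(2, 5), pty = "s", mar = c(1, 1, 1, 1))
for (idx in potential_outliers) {
  im <- matrix(mnist[idx, ] %>% select(-label) %>% unlist(), nrow = 28, byrow = FALSE)
  im <- im[, ncol(im):1]  # same row-reversal as above, done in one step
  image(1:28, 1:28, im, col = gray(rev((0:255)/255)), axes = FALSE, xlab = "", ylab = "")
}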
There are other metrics useful for detecting outliers, including the k-nearest-neighbor distance, the local outlier factor (LOF), and isolation scores based on isolation forests.
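As a quick sketch of the first two, the dbscan package (already loaded) provides kNNdist() and lof(); the values of k and minPts below are illustrative, and isolation scores would require a separate package such as isotree.
# Distance from each point to its 10th nearest neighbor: large values flag isolated points
knn_dist <- kNNdist(mnist_umap, k = 10)
head(order(knn_dist, decreasing = TRUE), 10)
# Local outlier factor: values well above 1 indicate points sitting in a sparser
# region than their neighbors
lof_scores <- lof(mnist_umap, minPts = 10)
head(order(lof_scores, decreasing = TRUE), 10)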